import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg import inv
from scipy.linalg import schur, sqrtm
from copy import deepcopy
from scipy.spatial.distance import cdist
import torch
import torch.nn.functional as F

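########## NumPy helpers (a few PyTorch counterparts, e.g. LAFs2ellT / invSqrtTorch / bsvd2x2, are interleaved here)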
def invSqrt(a, b, c):
    """Return the inverse square root of a symmetric 2x2 matrix, scaled to unit determinant.

    The matrix is ``M = [[a, b], [b, c]]``; ``a``, ``b``, ``c`` may be scalars or
    equally shaped NumPy arrays (processed element-wise). ``M`` is diagonalized with
    one Jacobi rotation. Each eigenvalue ``l`` is mapped to ``1 / sqrt(l)``, and the
    result is rescaled so that its determinant is 1.

    Returns:
        ``(new_a, new_b, new_c)``: entries of ``[[new_a, new_b], [new_b, new_c]]``.
    """
    eps = 1e-12
    mask = b != 0
    # Jacobi rotation: r1 = cot(2*phi); t1 = tan(phi) is the smaller root.
    r1 = mask * (c - a) / (2.0 * b + eps)
    # sign(0) must be +1 here, as in the classical Jacobi method.
    # np.sign(0) == 0 would give t1 = 0 (no rotation), which is wrong when a == c and b != 0.
    sgn = np.where(r1 < 0, -1.0, 1.0)
    t1 = sgn / (np.abs(r1) + np.sqrt(1.0 + r1 * r1))
    r = 1.0 / np.sqrt(1.0 + t1 * t1)  # cos(phi)
    t = t1 * r  # sin(phi)

    # Already diagonal (b == 0): use the identity rotation.
    r = r * mask + 1.0 * (1.0 - mask)
    t = t * mask

    # 1/sqrt of the eigenvalues: Rayleigh quotients for eigenvectors (r, -t) and (t, r).
    x = 1.0 / np.sqrt(r * r * a - 2 * r * t * b + t * t * c)
    z = 1.0 / np.sqrt(t * t * a + 2 * r * t * b + r * r * c)

    # Normalize so that x * z == 1, i.e. unit determinant.
    d = np.sqrt(x * z)
    x = x / d
    z = z / d

    # Rotate back: R diag(x, z) R^T.
    new_a = r * r * x + t * t * z
    new_b = -r * t * x + t * r * z
    new_c = t * t * x + r * r * z

    return new_a, new_b, new_c


def LAFs2ellT(LAFs):
    """Convert a batch of LAFs to ellipses (PyTorch counterpart of ``LAFs2ell``).

    Args:
        LAFs: tensor of shape (N, 2, 3); ``LAFs[:, :, :2]`` is the affine part ``A``
            and ``LAFs[:, :, 2]`` the center.

    Returns:
        Tensor of shape (N, 5), rows ``(x, y, a, b, c)``, where
        ``[[a, b], [b, c]] = U diag(1 / (scale * sigma)^2) U^T`` from the SVD of
        ``A / scale``. It is allocated with ``torch.zeros`` and moved to the GPU
        when ``LAFs`` is on the GPU.
    """
    ellipses = torch.zeros((len(LAFs), 5))
    if LAFs.is_cuda:
        ellipses = ellipses.cuda()
    # Isotropic scale = sqrt(det A); the eps keeps the sqrt finite when det A is 0.
    scale = torch.sqrt(LAFs[:, 0, 0] * LAFs[:, 1, 1] - LAFs[:, 0, 1] * LAFs[:, 1, 0] + 1e-10)
    unscaled_As = LAFs[:, 0:2, 0:2] / scale.view(-1, 1, 1)
    u, sig, _ = bsvd2x2(unscaled_As)
    # Build the new diagonal in a fresh tensor. Overwriting `sig` in place (as before)
    # modifies a tensor that autograd saved for `** 2`, which breaks backward().
    scale_sq = scale * scale
    W = torch.zeros_like(sig)
    W[:, 0, 0] = 1.0 / (scale_sq * sig[:, 0, 0] ** 2)
    W[:, 1, 1] = 1.0 / (scale_sq * sig[:, 1, 1] ** 2)
    A = torch.bmm(torch.bmm(u, W), u.permute(0, 2, 1))
    ellipses[:, 0] = LAFs[:, 0, 2]
    ellipses[:, 1] = LAFs[:, 1, 2]
    ellipses[:, 2] = A[:, 0, 0]
    ellipses[:, 3] = A[:, 0, 1]
    ellipses[:, 4] = A[:, 1, 1]
    return ellipses

def LAF2A(laf, h, w, PS=32):
    """Rescale one LAF by ``2 / PS`` and shift its translation by half a patch.

    The 2x2 affine part is multiplied by ``2 / PS``. Both translation entries are
    then decreased by ``s / 2``, where ``s = PS * sqrt(|det|)`` of the rescaled part.
    ``h`` and ``w`` are accepted for interface compatibility but are not used.
    The input is deep-copied and left unmodified.
    """
    patch_size = float(PS)
    out = deepcopy(laf)
    for row in (0, 1):
        out[row, :2] = 2.0 * laf[row, :2] / patch_size
    det = out[0, 0] * out[1, 1] - out[0, 1] * out[1, 0]
    half_side = patch_size * np.sqrt(np.abs(det)) / 2.0
    for row in (0, 1):
        out[row, 2] = out[row, 2] - half_side
    return out

def invSqrtTorch(a, b, c):
    """PyTorch version of ``invSqrt``: unit-determinant inverse sqrt of ``[[a, b], [b, c]]``.

    ``a``, ``b``, ``c`` are equally shaped tensors processed element-wise.

    Returns:
        ``(new_a, new_b, new_c)``: entries of ``[[new_a, new_b], [new_b, new_c]]``,
        scaled so that its determinant is 1.
    """
    eps = 1e-12
    mask = (b != 0).float()
    # Jacobi rotation: r1 = cot(2*phi); t1 = tan(phi) is the smaller root.
    r1 = mask * (c - a) / (2.0 * b + eps)
    # sign(0) must be +1 here. torch.sign(0) == 0 would skip the rotation
    # entirely when a == c and b != 0, giving a wrong result.
    sgn = 1.0 - 2.0 * (r1 < 0).to(r1.dtype)
    t1 = sgn / (torch.abs(r1) + torch.sqrt(1.0 + r1 * r1))
    r = 1.0 / torch.sqrt(1.0 + t1 * t1)  # cos(phi)
    t = t1 * r  # sin(phi)
    # Already diagonal (b == 0): use the identity rotation.
    r = r * mask + 1.0 * (1.0 - mask)
    t = t * mask

    # 1/sqrt of the eigenvalues: Rayleigh quotients for eigenvectors (r, -t) and (t, r).
    x = 1.0 / torch.sqrt(r * r * a - 2.0 * r * t * b + t * t * c)
    z = 1.0 / torch.sqrt(t * t * a + 2.0 * r * t * b + r * r * c)

    # Normalize so that x * z == 1, i.e. unit determinant.
    d = torch.sqrt(x * z)
    x = x / d
    z = z / d

    # Rotate back: R diag(x, z) R^T.
    new_a = r * r * x + t * t * z
    new_b = -r * t * x + t * r * z
    new_c = t * t * x + r * r * z

    return new_a, new_b, new_c


def ells2LAFsT(ells):
    """Convert a batch of ellipses to upright LAFs (PyTorch counterpart of ``ells2LAFs``).

    Args:
        ells: tensor of shape (N, 5), rows ``(x, y, a, b, c)`` describing the ellipse
            matrix ``E = [[a, b], [b, c]]`` centered at ``(x, y)``.

    Returns:
        Tensor of shape (N, 2, 3) allocated on the same device as ``ells``
        (previously always on the CPU, unlike ``LAFs2ellT``).
    """
    LAFs = torch.zeros((len(ells), 2, 3), device=ells.device)
    LAFs[:, 0, 2] = ells[:, 0]
    LAFs[:, 1, 2] = ells[:, 1]
    a = ells[:, 2]
    b = ells[:, 3]
    c = ells[:, 4]
    # det(E) ** (1/4): dividing the unit-determinant E^(-1/2) by it restores the true scale.
    sc = torch.sqrt(torch.sqrt(a * c - b * b + 1e-12))
    # Closed-form inverse sqrt (the original author noted scipy.linalg.sqrtm
    # returned a [::-1, ::-1]-flipped matrix).
    ia, ib, ic = invSqrtTorch(a, b, c)
    A = torch.stack([torch.stack([ia, ib], dim=1), torch.stack([ib, ic], dim=1)], dim=1) / sc.view(-1, 1, 1)
    # Rectify the unit-determinant shape ("up is up"), then put the scale back.
    sc = torch.sqrt(torch.abs(A[:, 0, 0] * A[:, 1, 1] - A[:, 1, 0] * A[:, 0, 1])).view(-1, 1, 1)
    LAFs[:, 0:2, 0:2] = rectifyAffineTransformationUpIsUp(A / sc) * sc
    return LAFs


def LAFs_to_H_frames(aff_pts):
    """Append the homogeneous row ``[0, 0, 1]`` to a batch of 2x3 affine frames.

    Args:
        aff_pts: tensor of shape (N, 2, 3).

    Returns:
        Tensor of shape (N, 3, 3) with the same dtype and device as ``aff_pts``.
    """
    # Allocate from the input so torch.cat never mixes dtypes (float32 vs float64)
    # or devices.
    last_row = aff_pts.new_zeros((aff_pts.size(0), 1, 3))
    last_row[:, 0, 2] = 1
    return torch.cat([aff_pts, last_row], dim=1)


def checkTouchBoundary(LAFs):
    """Flag LAFs whose canonical square stays inside the normalized image.

    The corners ``(+-1, +-1)`` are mapped through each 2x3 frame. A LAF is "good"
    when every mapped corner coordinate lies in ``[0, 1]``.

    Args:
        LAFs: tensor of shape (N, 2, 3), presumably in normalized [0, 1] coordinates.

    Returns:
        Bool tensor of shape (N,): ``True`` for LAFs that do not touch the boundary.
    """
    # Homogeneous corners as columns (x, y, 1), in the LAFs' dtype/device.
    corners = LAFs.new_tensor([[-1, -1, 1, 1], [-1, 1, -1, 1], [1, 1, 1, 1]])
    # (N, 2, 3) @ (3, 4) -> (N, 2, 4): the first two rows of [LAF; 0 0 1] @ corners.
    out_pts = torch.matmul(LAFs, corners)
    outside = (out_pts > 1.0) | (out_pts < 0.0)
    touches = outside.sum(dim=1).sum(dim=1) > 0
    # `1 - bool_tensor` raises on modern PyTorch; negate the mask instead.
    return ~touches


def bsvd2x2(As):
    """Closed-form batched SVD of 2x2 matrices, ``As ~= U @ SIG @ V^T``.

    Args:
        As: tensor of shape (N, 2, 2).

    Returns:
        ``(U, SIG, V)``, each of shape (N, 2, 2), with the same dtype and device as
        ``As``. ``U`` is a rotation ``[[cos, -sin], [sin, cos]]``. ``SIG`` is diagonal
        with ``SIG[:, 0, 0] >= SIG[:, 1, 1]``. ``V`` is a rotation times a diagonal
        sign matrix.
    """

    def _rotation(angle):
        # Batched 2D rotation [[cos, -sin], [sin, cos]] in As' dtype/device.
        R = As.new_zeros((As.size(0), 2, 2))
        cos_a = torch.cos(angle)
        sin_a = torch.sin(angle)
        R[:, 0, 0] = cos_a
        R[:, 1, 1] = cos_a
        R[:, 0, 1] = -sin_a
        R[:, 1, 0] = sin_a
        return R

    # Left singular vectors: eigenvectors of A A^T. For a symmetric [[p, q], [q, r]],
    # 0.5 * atan2(2q, p - r) is the angle of the principal eigenvector.
    Su = torch.bmm(As, As.permute(0, 2, 1))
    phi = 0.5 * torch.atan2(Su[:, 0, 1] + Su[:, 1, 0] + 1e-12, Su[:, 0, 0] - Su[:, 1, 1] + 1e-12)
    U = _rotation(phi)
    # Right singular vectors: eigenvectors of A^T A.
    Sw = torch.bmm(As.permute(0, 2, 1), As)
    theta = 0.5 * torch.atan2(Sw[:, 0, 1] + Sw[:, 1, 0] + 1e-12, Sw[:, 0, 0] - Sw[:, 1, 1] + 1e-12)
    W = _rotation(theta)
    # Singular values: square roots of the eigenvalues of A A^T.
    SUsum = Su[:, 0, 0] + Su[:, 1, 1]
    SUdif = torch.sqrt((Su[:, 0, 0] - Su[:, 1, 1]) ** 2 + 4 * Su[:, 0, 1] * Su[:, 1, 0] + 1e-12)
    SIG = As.new_zeros((As.size(0), 2, 2))
    SIG[:, 0, 0] = torch.sqrt((SUsum + SUdif) / 2.0)
    SIG[:, 1, 1] = torch.sqrt((SUsum - SUdif) / 2.0)
    # U^T A W is diagonal up to the signs of its entries; fold those signs into V.
    S = torch.bmm(torch.bmm(U.permute(0, 2, 1), As), W)
    C = torch.sign(S)
    C[:, 0, 1] = 0
    C[:, 1, 0] = 0
    V = torch.bmm(W, C)
    return (U, SIG, V)


def getLAFelongation(LAFs):
    """Return each LAF's elongation: larger / smaller singular value of its 2x2 part."""
    _, singular, _ = bsvd2x2(LAFs[:, :2, :2])
    s_first = singular[:, 0, 0]
    s_second = singular[:, 1, 1]
    return torch.max(s_first, s_second) / torch.min(s_first, s_second)


def getNumCollapsed(LAFs, th=10.0):
    """Count LAFs whose elongation exceeds ``th``; the count is a float scalar tensor."""
    is_collapsed = getLAFelongation(LAFs) > th
    return is_collapsed.float().sum()


def Ell2LAF(ell):
    """Convert one ellipse ``(x, y, a, b, c)`` to an upright 2x3 LAF (NumPy).

    The ellipse matrix is ``E = [[a, b], [b, c]]``. The returned affine part is
    lower-triangular ("up is up") and has determinant ``det(E) ** -0.5``; the last
    column is the center ``(x, y)``.
    """
    A23 = np.zeros((2, 3))
    A23[0, 2] = ell[0]
    A23[1, 2] = ell[1]
    a = ell[2]
    b = ell[3]
    c = ell[4]
    # det(E) ** (1/4): turns the unit-determinant E^(-1/2) from invSqrt into the true one.
    sc = np.sqrt(np.sqrt(a * c - b * b))
    # Closed-form inverse sqrt (the original author noted scipy.linalg.sqrtm
    # returned a [::-1, ::-1]-flipped matrix).
    ia, ib, ic = invSqrt(a, b, c)
    A = np.array([[ia, ib], [ib, ic]]) / sc
    sc = np.sqrt(A[0, 0] * A[1, 1] - A[1, 0] * A[0, 1])
    # Must use the NumPy rectifier: the torch `rectifyAffineTransformationUpIsUp`
    # indexes a batch axis (A[:, 0, 0]) and calls torch ops, so it fails on a 2x2 ndarray.
    A23[0:2, 0:2] = rectifyAffineTransformationUpIsUp_np(A / sc) * sc
    return A23


def rectifyAffineTransformationUpIsUp_np(A):
    """Return the lower-triangular, unit-determinant ("up is up") form of a 2x2 affine ``A``.

    NumPy counterpart of ``rectifyAffineTransformationUpIsUp`` for a single matrix.
    The result is a new (2, 2) float64 array.
    """
    a11, a12 = A[0, 0], A[0, 1]
    a21, a22 = A[1, 0], A[1, 1]
    sqrt_det = np.sqrt(np.abs(a11 * a22 - a21 * a12 + 1e-10))
    row_norm = np.sqrt(a12 * a12 + a11 * a11)
    return np.array(
        [
            [row_norm / sqrt_det, 0],
            [(a22 * a12 + a21 * a11) / (row_norm * sqrt_det), sqrt_det / row_norm],
        ],
        dtype=float,
    )


def ells2LAFs(ells):
    """Convert an (N, 5) array of ellipses to an (N, 2, 3) array of LAFs via ``Ell2LAF`` per row."""
    frames = np.zeros((len(ells), 2, 3))
    for idx, ell in enumerate(ells):
        frames[idx] = Ell2LAF(ell)
    return frames


def LAF2pts(LAF, n_pts=50):
    """Sample a LAF's outline for plotting.

    Returns an array of shape (n_pts + 1, 2). The first row is the LAF center.
    The remaining rows are the points ``(sin t, cos t)``, ``t`` in ``linspace(0, 2*pi, n_pts)``,
    mapped through the 2x3 affine frame.
    """
    angles = np.linspace(0, 2 * np.pi, n_pts)
    xs = np.concatenate([[0.0], np.sin(angles)])
    ys = np.concatenate([[0.0], np.cos(angles)])
    homogeneous = np.vstack([xs, ys, np.ones_like(xs)])
    frame = np.concatenate([LAF, np.array([0, 0, 1]).reshape(1, 3)])
    mapped = np.matmul(frame, homogeneous).T
    return mapped[:, 0:2] / mapped[:, 2:3]


def convertLAFs_to_A23format(LAFs):
    """Bring LAFs from a supported layout into an (N, 2, 3) ``[A | (x; y)]`` array.

    Supported inputs:
      * (N, 2, 3): already in that form; a deep copy is returned.
      * (N, 7): flat ``x y scale a11 a12 a21 a22``; the affine part is ``scale * a``.
      * (N, 6): flat ``x y s*a11 s*a12 s*a21 s*a22``.

    For any other shape, prints "Unknown LAF format" and returns None.
    """
    shape = LAFs.shape
    n_dims = len(shape)
    if n_dims == 3 and shape[1] == 2 and shape[2] == 3:
        return deepcopy(LAFs)
    if n_dims == 2 and shape[1] in (6, 7):
        out = np.zeros((shape[0], 2, 3))
        out[:, 0, 2] = LAFs[:, 0]
        out[:, 1, 2] = LAFs[:, 1]
        entries = ((0, 0), (0, 1), (1, 0), (1, 1))
        if shape[1] == 7:
            for offset, (row, col) in enumerate(entries):
                out[:, row, col] = LAFs[:, 2] * LAFs[:, 3 + offset]
        else:
            for offset, (row, col) in enumerate(entries):
                out[:, row, col] = LAFs[:, 2 + offset]
        return out
    print("Unknown LAF format")
    return None


def LAFs2ell(in_LAFs):
    """Convert LAFs (any layout accepted by ``convertLAFs_to_A23format``) to ellipses.

    Returns an (N, 5) array, rows ``(x, y, a, b, c)``, where
    ``[[a, b], [b, c]] = u diag(1 / (sigma * scale)^2) u^T``. Here ``u, sigma`` come
    from the SVD of ``A / scale`` and ``scale = sqrt(det A)``.
    """
    frames = convertLAFs_to_A23format(in_LAFs)
    ellipses = np.zeros((len(frames), 5))
    for idx, frame in enumerate(frames):
        # The eps keeps the sqrt finite when det A is 0.
        scale = np.sqrt(frame[0, 0] * frame[1, 1] - frame[0, 1] * frame[1, 0] + 1e-10)
        u, sing, _ = np.linalg.svd(frame[0:2, 0:2] / scale, full_matrices=True)
        inv_sq = 1.0 / (sing * sing * scale * scale)
        ell_mat = np.matmul(np.matmul(u, np.diag(inv_sq)), u.transpose())
        ellipses[idx] = (frame[0, 2], frame[1, 2], ell_mat[0, 0], ell_mat[0, 1], ell_mat[1, 1])
    return ellipses


def visualize_LAFs(img, LAFs, color="r", show=False, save_to=None):
    """Draw LAF outlines on top of the inverted image ``255 - img`` in a new figure.

    Args:
        img: image array passed (inverted) to ``plt.imshow``.
        LAFs: LAFs in any layout accepted by ``convertLAFs_to_A23format``.
        color: matplotlib format string for the outlines.
        show: call ``plt.show()`` after plotting.
        save_to: optional output path for ``plt.savefig``.
    """
    work_LAFs = convertLAFs_to_A23format(LAFs)
    plt.figure()
    plt.imshow(255 - img)
    for laf in work_LAFs:
        ell = LAF2pts(laf)
        plt.plot(ell[:, 0], ell[:, 1], color)
    # Save before showing. After a blocking plt.show() returns, the figure is
    # closed, and a later plt.savefig() writes a new, empty figure.
    if save_to is not None:
        plt.savefig(save_to)
    if show:
        plt.show()
    return


#### PyTorch helpers (batched tensor versions)


def get_normalized_affine_shape(tilt, angle_in_radians):
    """Build upright affine shapes from per-sample tilts and angles.

    For each sample, ``R @ diag(tilt, 1) @ R`` (``R`` from ``get_rotation_matrix``)
    is passed through ``rectifyAffineTransformationUpIsUp``.

    Args:
        tilt: tensor with one tilt per sample (flattened with ``view(-1)``).
        angle_in_radians: tensor with one angle per sample.

    Returns:
        Tensor of shape (N, 2, 2).
    """
    assert tilt.size(0) == angle_in_radians.size(0)
    num = tilt.size(0)
    rotmat = get_rotation_matrix(angle_in_radians)
    # Match rotmat's dtype/device. A hard-coded float32 eye made torch.bmm fail
    # for float64 angles.
    tilt_A = torch.eye(2, dtype=rotmat.dtype, device=rotmat.device).view(1, 2, 2).repeat(num, 1, 1)
    tilt_A[:, 0, 0] = tilt.view(-1)
    out_A = rectifyAffineTransformationUpIsUp(torch.bmm(rotmat, torch.bmm(tilt_A, rotmat)))
    # NOTE: a disabled heuristic used to multiply out_A by
    # 1 / sqrt(max squared column norm) to keep the scale change small.
    return out_A


def get_rotation_matrix(angle_in_radians):
    """Return an (N, 2, 2) tensor of ``[[cos, sin], [-sin, cos]]``, one per angle."""
    theta = angle_in_radians.view(-1)
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    first_row = torch.stack([cos_t, sin_t], dim=-1)
    second_row = torch.stack([-sin_t, cos_t], dim=-1)
    return torch.stack([first_row, second_row], dim=1)


def rectifyAffineTransformationUpIsUp(A):
    """Return the lower-triangular, unit-determinant ("up is up") form of each 2x2 in ``A``.

    ``A`` is indexed as (N, 2, 2). The upper-right entry is ``0 * sqrt_det``, so
    NaNs in the determinant still propagate.
    """
    a11, a12 = A[:, 0, 0], A[:, 0, 1]
    a21, a22 = A[:, 1, 0], A[:, 1, 1]
    sqrt_det = torch.sqrt(torch.abs(a11 * a22 - a21 * a12 + 1e-10))
    row_norm = torch.sqrt(a12 * a12 + a11 * a11)
    top = torch.stack([row_norm / sqrt_det, 0 * sqrt_det], dim=1)
    bottom = torch.stack([(a22 * a12 + a21 * a11) / (row_norm * sqrt_det), sqrt_det / row_norm], dim=1)
    return torch.stack([top, bottom], dim=1)


def rectifyAffineTransformationUpIsUpFullyConv(A):  # A is (n,4,h,w) tensor
    """Per-pixel version of ``rectifyAffineTransformationUpIsUp``.

    The channels of ``A`` hold ``a11, a12, a21, a22``; the result uses the same layout.
    """
    a11 = A[:, 0:1, :, :]
    a12 = A[:, 1:2, :, :]
    a21 = A[:, 2:3, :, :]
    a22 = A[:, 3:4, :, :]
    sqrt_det = torch.sqrt(torch.abs(a11 * a22 - a12 * a21 + 1e-10))
    row_norm = torch.sqrt(a12 * a12 + a11 * a11)
    channels = [
        row_norm / sqrt_det,
        0 * sqrt_det,
        (a22 * a12 + a21 * a11) / (row_norm * sqrt_det),
        sqrt_det / row_norm,
    ]
    return torch.cat(channels, dim=1)


def abc2A(a, b, c, normalize=False):
    """Assemble symmetric matrices ``[[a, b], [b, c]]``, returned as an (N, 2, 2) tensor.

    ``normalize`` is accepted for interface compatibility and is ignored.
    """
    flat_a, flat_b, flat_c = a.view(-1), b.view(-1), c.view(-1)
    upper = torch.stack([flat_a, flat_b], dim=1)
    lower = torch.stack([flat_b, flat_c], dim=1)
    return torch.stack([upper, lower], dim=1)


def angles2A(angles):
    """Return an (N, 2, 2) tensor of ``[[cos, sin], [-sin, cos]]`` for the given angles."""
    cos_v = torch.cos(angles).view(-1)
    sin_v = torch.sin(angles).view(-1)
    # Row-major entries (cos, sin, -sin, cos) reshaped into 2x2 blocks.
    return torch.stack([cos_v, sin_v, -sin_v, cos_v], dim=1).view(-1, 2, 2)


def generate_patch_grid_from_normalized_LAFs(LAFs, w, h, PS):
    """Build a ``grid_sample`` grid of PS x PS patches for normalized LAFs.

    First the LAFs are scaled to pixels: the affine part by ``min(h, w)``, the x/y
    translations by ``w``/``h``. They then go through ``F.affine_grid``, and the
    resulting coordinates are mapped with ``2 * p / size - 1``.

    Returns a grid of shape (N, PS, PS, 2).
    """
    n_lafs = LAFs.size(0)
    pixel_coef = torch.ones(1, 2, 3) * min(h, w)
    pixel_coef[0, 0, 2] = w
    pixel_coef[0, 1, 2] = h
    if LAFs.is_cuda:
        pixel_coef = pixel_coef.cuda()
    pixel_lafs = LAFs * pixel_coef.expand(n_lafs, 2, 3)
    grid = F.affine_grid(pixel_lafs, torch.Size((n_lafs, 1, PS, PS)))
    x_norm = 2.0 * grid[:, :, :, 0] / float(w) - 1.0
    y_norm = 2.0 * grid[:, :, :, 1] / float(h) - 1.0
    return torch.stack([x_norm, y_norm], dim=3)


def batched_grid_apply(img, grid, batch_size=32):
    """Run ``F.grid_sample`` over ``grid`` in chunks of at most ``batch_size`` patches.

    Args:
        img: (B, C, H, W) tensor. When ``B != grid.size(0)``, the image is expanded to
            every patch; otherwise patch ``i`` samples from ``img[i]``.
        grid: (N, PS, PS, 2) sampling grid.
        batch_size: maximum number of patches per ``grid_sample`` call.

    Returns:
        (N, C, PS, PS) tensor of sampled patches.
    """
    n_patches = len(grid)
    shared_img = img.size(0) != grid.size(0)

    def _sample(st, end):
        # One grid_sample call for patches [st, end).
        if shared_img:
            src = img.expand(end - st, img.size(1), img.size(2), img.size(3))
        else:
            src = img[st:end]
        return F.grid_sample(src, grid[st:end])

    if n_patches <= batch_size:
        return _sample(0, n_patches)

    # range() with a step replaces the old `n_patches / bs + 1` batch count, which is
    # a float on Python 3 and made range() raise TypeError.
    first = _sample(0, batch_size)
    out = first.new_zeros((n_patches,) + tuple(first.shape[1:]))
    out[:batch_size] = first
    for st in range(batch_size, n_patches, batch_size):
        end = min(st + batch_size, n_patches)
        out[st:end] = _sample(st, end)
    return out


def extract_patches(img, LAFs, PS=32, bs=32):
    """Sample one PS x PS patch per normalized LAF from ``img`` (indexed as (B, C, H, W)).

    With ``bs=None``, all patches come from a single ``grid_sample`` call on the
    expanded image. Otherwise ``batched_grid_apply`` handles ``bs`` patches at a time.
    """
    channels = img.size(1)
    height = img.size(2)
    width = img.size(3)
    grid = generate_patch_grid_from_normalized_LAFs(LAFs, float(width), float(height), PS)
    if bs is not None:
        return batched_grid_apply(img, grid, bs)
    return torch.nn.functional.grid_sample(img.expand(grid.size(0), channels, height, width), grid)


def get_pyramid_inverted_index_for_LAFs(LAFs, PS, sigmas):
    """Unimplemented placeholder: ignores its arguments and returns None.

    The implemented inverted-index builder in this module is ``get_inverted_pyr_index``.
    """
    return None


def extract_patches_from_pyramid_with_inv_index(scale_pyramid, pyr_inv_idxs, LAFs, PS=19):
    """Extract one patch per LAF, each from the pyramid image assigned to it.

    Args:
        scale_pyramid: nested sequence; ``scale_pyramid[i][j]`` is the image tensor
            for octave ``i``, level ``j``.
        pyr_inv_idxs: nested sequence with the same structure. Each entry holds the
            indices of the LAFs to sample from that image, or None (e.g. the output of
            ``get_inverted_pyr_index``). If None, all patches stay zero.
        LAFs: normalized LAFs indexed as (N, 2, 3).
        PS: patch size.

    Returns:
        Zero-initialized (N, C, PS, PS) tensor (C = channels of
        ``scale_pyramid[0][0]``), filled per assigned level; on the GPU when LAFs is.
    """
    patches = torch.zeros(LAFs.size(0), scale_pyramid[0][0].size(1), PS, PS)
    if LAFs.is_cuda:
        patches = patches.cuda()
    if pyr_inv_idxs is None:
        return patches
    for octave_idx, octave in enumerate(scale_pyramid):
        for level_idx, level_img in enumerate(octave):
            lvl_idxs = pyr_inv_idxs[octave_idx][level_idx]
            if lvl_idxs is None:
                continue
            # view(-1) also handles the 0-d index produced for a single match.
            lvl_idxs = lvl_idxs.view(-1)
            patches[lvl_idxs, :, :, :] = extract_patches(level_img, LAFs[lvl_idxs, :, :], PS, 32)
    return patches


def get_inverted_pyr_index(scale_pyr, pyr_idxs, level_idxs):
    """For every (octave, level) slot of ``scale_pyr``, collect the items assigned to it.

    Args:
        scale_pyr: nested sequence; only the lengths of the octaves are used.
        pyr_idxs: tensor with the octave index of each item (presumably 1-D).
        level_idxs: tensor with the level index of each item, same shape as ``pyr_idxs``.

    Returns:
        Nested list ``inv[i][j]``. Each entry is the ``torch.nonzero`` positions with
        ``pyr_idxs == i`` and ``level_idxs == j``, squeezed (so 0-d for a single
        match), or None when no item matches.
    """
    inverted = []
    for octave_idx, octave in enumerate(scale_pyr):
        in_octave = pyr_idxs == octave_idx
        per_level = []
        for level_idx in range(len(octave)):
            hits = torch.nonzero(((level_idxs == level_idx) * in_octave).data)
            per_level.append(None if hits.size(0) == 0 else hits.squeeze())
        inverted.append(per_level)
    return inverted


def denormalizeLAFs(LAFs, w, h):
    """Map (N, 2, 3) LAFs from normalized to pixel coordinates.

    The affine part is multiplied by ``min(h, w)``, the x translation by ``w`` and
    the y translation by ``h``. The multiplier is float32 and is moved to the GPU
    when ``LAFs`` is on the GPU.
    """
    width = float(w)
    height = float(h)
    multiplier = torch.ones(1, 2, 3).float() * min(height, width)
    multiplier[0, 0, 2] = width
    multiplier[0, 1, 2] = height
    if LAFs.is_cuda:
        multiplier = multiplier.cuda()
    return multiplier.expand(LAFs.size(0), 2, 3) * LAFs


def normalizeLAFs(LAFs, w, h):
    """Map (N, 2, 3) LAFs from pixel to normalized coordinates.

    The affine part is divided by ``min(h, w)``, the x translation by ``w`` and the
    y translation by ``h``. This is the inverse of ``denormalizeLAFs``.
    """
    width = float(w)
    height = float(h)
    divisor_inv = torch.ones(1, 2, 3).float() / min(height, width)
    divisor_inv[0, 0, 2] = 1.0 / width
    divisor_inv[0, 1, 2] = 1.0 / height
    if LAFs.is_cuda:
        divisor_inv = divisor_inv.cuda()
    return divisor_inv.expand(LAFs.size(0), 2, 3) * LAFs


def sc_y_x2LAFs(sc_y_x):
    """Build isotropic LAFs from ``(scale, y, x)`` rows.

    Args:
        sc_y_x: (N, 3) tensor with columns ``scale, y, x``.

    Returns:
        (N, 2, 3) tensor with rows ``[[scale, 0, x], [0, scale, y]]``, in the input's
        dtype and device.
    """
    # Previously this referenced an undefined `base_A` and always raised NameError.
    # The identity is now also created in the input's dtype/device.
    base_LAF = torch.eye(2, dtype=sc_y_x.dtype, device=sc_y_x.device).unsqueeze(0).expand(sc_y_x.size(0), 2, 2)
    A = sc_y_x[:, :1].unsqueeze(1).expand_as(base_LAF) * base_LAF
    # Translation column is (x, y): columns 2 and 1 of the input.
    center = torch.cat([sc_y_x[:, 2:].unsqueeze(-1), sc_y_x[:, 1:2].unsqueeze(-1)], dim=1)
    LAFs = torch.cat([A, center], dim=2)

    return LAFs


def sc_y_x_and_A2LAFs(sc_y_x, A_flat):
    """Build LAFs from ``(scale, y, x)`` rows and flattened 2x2 shapes.

    Args:
        sc_y_x: (N, 3) tensor with columns ``scale, y, x``.
        A_flat: tensor with 4 values per LAF, viewed row-major as (N, 2, 2).

    Returns:
        (N, 2, 3) tensor ``[scale * A | (x; y)]``.
    """
    shapes = A_flat.view(-1, 2, 2)
    scales = sc_y_x[:, :1].unsqueeze(1).expand_as(shapes)
    affine = scales * shapes
    xs = sc_y_x[:, 2:].unsqueeze(-1)
    ys = sc_y_x[:, 1:2].unsqueeze(-1)
    center = torch.cat([xs, ys], dim=1)
    return torch.cat([affine, center], dim=2)


def get_LAFs_scales(LAFs):
    """Return ``sqrt(|det A| + 1e-12)`` for the 2x2 affine part ``A`` of each LAF."""
    det = LAFs[:, 0, 0] * LAFs[:, 1, 1] - LAFs[:, 0, 1] * LAFs[:, 1, 0]
    return torch.sqrt(torch.abs(det) + 1e-12)


def get_pyramid_and_level_index_for_LAFs(dLAFs, sigmas, pix_dists, PS):
    """Pick, for every LAF, the pyramid image whose effective sigma is closest to its need.

    Args:
        dLAFs: (N, 2, 3) LAFs (their scale is taken from ``get_LAFs_scales``).
        sigmas: per-octave sequences of level sigmas.
        pix_dists: per-octave factor, multiplied element-wise with that octave's sigmas.
        PS: patch size; a LAF of scale ``s`` needs sigma ``s / PS``.

    Returns:
        ``(octave_idxs, level_idxs)``: LongTensors of length N, on dLAFs' GPU if it is
        on one.
    """
    scales = get_LAFs_scales(dLAFs)
    needed_sigmas = scales / PS
    sigmas_full_list = []
    level_idxs_full = []
    oct_idxs_full = []
    for oct_idx, oct_sigmas in enumerate(sigmas):
        sigmas_full_list.extend(np.array(oct_sigmas) * np.array(pix_dists[oct_idx]))
        oct_idxs_full.extend([oct_idx] * len(oct_sigmas))
        # extend() instead of `list + range(...)`, which raises TypeError on Python 3.
        level_idxs_full.extend(range(len(oct_sigmas)))
    oct_idxs_full = torch.LongTensor(oct_idxs_full)
    level_idxs_full = torch.LongTensor(level_idxs_full)

    # Index of the nearest available sigma for every LAF (1-D distances via cdist).
    closest_imgs = cdist(np.array(sigmas_full_list).reshape(-1, 1), needed_sigmas.data.cpu().numpy().reshape(-1, 1)).argmin(axis=0)
    closest_imgs = torch.from_numpy(closest_imgs)
    if dLAFs.is_cuda:
        closest_imgs = closest_imgs.cuda()
        oct_idxs_full = oct_idxs_full.cuda()
        level_idxs_full = level_idxs_full.cuda()
    return oct_idxs_full[closest_imgs], level_idxs_full[closest_imgs]
